{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference using Bayesian Networks\n",
    "\n",
    "Inference is a powerful technique that combines empirical evidence and prior statistical beliefs to determine likelihood from causal connections. A classic example is presented below in graphical form (taken from [this blog](https://towardsdatascience.com/introduction-to-bayesian-networks-81031eeed94e); please excuse the convention of substituting commas for decimal places):\n",
    "\n",
    "![wet grass example](https://miro.medium.com/max/1385/1*9OsQV0PqM2juaOtGqoRISw.jpeg)\n",
    "\n",
    "As shown above, there are two causal explanations for why the grass is wet-- the sprinkler could be on, or it could have rained. We can enter the graph at any point to infer the conditional likelihood of each variable; for example, if we know the state of *cloudy*, we can replace the 50/50 prior probability, and propagate that through the graph. We can also reverse the direction; if we know that the grass is wet, we can infer the empirical likelihood of *cloudy* being true such that it is weighted by this evidence, instead of just the prior belief. Note that this means that it is possible to determine prior beliefs from data, by choosing an naive prior and replacing after acquiring empirical data.\n",
    "\n",
    "The classical example above maps to ICESat-2 data with some simple modifications. Let's sat that you observe an ICESat-2 QAQC flag that the data is \"not useful\", caused by the ATL06 algorithm failing to find the ground. There are several possible causal reasons for this: the surface could be obscured by optically thick clouds; the surface could be complex because of crevassing; there could blowing snow on the surface; there could be an instrument issue. If we discount the last option of instrumentation error, we can pull out useful information from the QAQC flag about the physical process that is causing the flag. If we have cloud data (i.e., from another data source like VIIRS), or wind data, we can start to build a statistical inference pipeline that will infer the likelihood of blowing snow or crevasses.\n",
    "\n",
    "### Numeric Considerations\n",
    "\n",
    "Bayesian nets can be used as simple feed forward networks (if we are confident in our priors), or trained to determine the conditional probabilities that the network itself uses. We'll use pymc3 in this notebook to examine how to formulate and build Bayesian networks-- first using discrete variables, then expanding to incorporate continuous variables. All of this falls under ***probabilistic programming***; there a wide variety of tasks that the framework can be used for. The pymc3 library uses Monte Carlo simulations for inference, which can sped up substantially by GPUs-- although they run just fine on CPUs as well.\n",
    "\n",
    "## Defining a Probabilistic Model\n",
    "\n",
    "We'll use a worked epidemiology example to show how to set both static parameters, and distributions over parameters, and then explore adding in empirical observations to the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pylab inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image\n",
    "from IPython.display import Markdown as md"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pymc3 as pm\n",
    "from pymc3 import Normal, Model\n",
    "from pymc3.math import switch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Image('./a4Murphy1.jpg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# General model\n",
    "Image('./a4Murphy2.jpg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with Model() as bayesnet:\n",
    "    # Parent\n",
    "    g1_healthy = pm.distributions.discrete.Bernoulli('g1',0.5)\n",
    "    \n",
    "    # Latent variable that conditions the children \n",
    "    g1_L = switch(g1_healthy > 0.5, 0.9, 0.1)\n",
    "    \n",
    "    # Left Child\n",
    "    g2_healthy = pm.distributions.discrete.Bernoulli('g2', g1_L)\n",
    "    # Right Child\n",
    "    g3_healthy = pm.distributions.discrete.Bernoulli('g3', g1_L)\n",
    "    # Conditional switches for blood pressure mean\n",
    "    rate1 = switch(g1_healthy > .5, 50, 60)\n",
    "    rate2 = switch(g2_healthy > .5, 50, 60)\n",
    "    rate3 = switch(g3_healthy > .5, 50, 60)\n",
    "    # Parent and children\n",
    "    x1 = Normal('x1', mu=rate1, sd=10**.5)\n",
    "    x2 = Normal('x2', mu=rate2, sd=10**.5)\n",
    "    x3 = Normal('x3', mu=rate3, sd=10**.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pm.model_graph.model_to_graphviz(model=bayesnet)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the defined model is identical to the reference figure, just with **$X_1$** being placed in below rather than above. The *'G'* parameters are all Bernoulli distributions that specify a binary outcome. The *'$X_*$'* variables are continuous distributions of blood pressure."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercise 1\n",
    "\n",
    "Compute P(G1 = 2 | X2 = 50). This is the probability that the parent (**$G_1$**) has the heredity disease given that the child (**$X_2$**) has blood pressure of 50."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute P(G1 = 2 | X2 = 50)\n",
    "\n",
    "with Model() as bayesnet:\n",
    "    # Parent\n",
    "    g1_healthy = pm.distributions.discrete.Bernoulli('g1',0.5)\n",
    "    \n",
    "    # Latent variable that conditions the children \n",
    "    g1_L = switch(g1_healthy > 0.5, 0.9, 0.1)\n",
    "    \n",
    "    # Left Child\n",
    "    g2_healthy = pm.distributions.discrete.Bernoulli('g2', g1_L)\n",
    "    # Right Child\n",
    "    g3_healthy = pm.distributions.discrete.Bernoulli('g3', g1_L)\n",
    "    # Conditional switches for blood pressure mean\n",
    "    rate1 = switch(g1_healthy > .5, 50, 60)\n",
    "    rate2 = switch(g2_healthy > .5, 50, 60)\n",
    "    rate3 = switch(g3_healthy > .5, 50, 60)\n",
    "    # Grandkids and child\n",
    "    x1 = Normal('x1', mu=rate1, sd=10**.5)\n",
    "    # Setting x2 as an observed variable\n",
    "    x2 = Normal('x2', mu=rate2, sd=10**.5, observed=50)\n",
    "    x3 = Normal('x3', mu=rate3, sd=10**.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the above is identical to our original model definition-- the only thing that has changed, is that we have set **$X_2$** as an observed variable. There is an analytic solution to this type of inference, but we'll use sampling to estimate an empirical solution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "with bayesnet:\n",
    "    mytrace = pm.sample(10000, tune=2000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that PyMC3 does the hard work of choosing the appropriate samplers for us :-) The binary variables use Gibbs Metropolis-Hastings, and the continuous use the No-U-Turn-Sampler. The variable **$X_2$** is not sampled, since we set it as observed. The result of the sampling is a *trace*; we can view the trace to get our answer to the original question (although it will take a second to plot given the sample size):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pm.traceplot(mytrace);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The original question was \"*what's the probability that the parent has the disease*\", so we can take the number of samples where **$G_1 = 1$** (is healthy), and divide by the total number of samples to get the percentage of the population in the simulation that was healthy: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p_g1_healthy = sum(mytrace['g1'])/len(mytrace['g1'])\n",
    "p_g1_not_healthy = 1 - p_g1_healthy\n",
    "p_g1_healthy, p_g1_not_healthy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "md(\"\"\"We find {}% of the G1 samples are healthy (G1 = 1); so  {}% are unhealthy (G1 = 2). \n",
    "   We can actually access the same information using the `summary()` function, which will also \n",
    "   provide an error estimate:\"\"\".format(p_g1_healthy.round(4)*100,p_g1_not_healthy.round(4)*100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pm.summary(mytrace, round_to=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stats = pm.summary(mytrace, 'g1', round_to=5)\n",
    "err = round(float(stats.mcse_mean),5)\n",
    "plus = round(float(1-stats['mean'] + stats.mcse_mean),5)\n",
    "minus = round(float(1-stats['mean'] - stats.mcse_mean),5)\n",
    "ans = round(float(1-stats['mean']),5)\n",
    "\n",
    "md(\"\"\"As mentioned, we could analytically compute the answer:\n",
    "\n",
    "- The analytic answer is 0.10535\n",
    "- The MCMC estimate is {} +/- {}\n",
    "\n",
    "Since 0.10535 is between {} and {}, seems that we did well.\"\"\".format(ans,err,plus,minus))"
   ]
  },
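  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, here is a sketch of where the analytic value comes from. It assumes only the numbers used in the model code above (a prior of 0.5 on the parent, inheritance probabilities of 0.9 and 0.1, and blood pressure means of 50 and 60 with variance 10); the labels $h$ (healthy) and $d$ (diseased) are shorthand introduced here.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "p(x_2 = 50 \\mid G_2 = h) &= \\mathcal{N}(50;\\, 50,\\, 10) = \\frac{1}{\\sqrt{20\\pi}} \\approx 0.12616 \\\\\n",
    "p(x_2 = 50 \\mid G_2 = d) &= \\mathcal{N}(50;\\, 60,\\, 10) = \\frac{e^{-5}}{\\sqrt{20\\pi}} \\approx 0.00085 \\\\\n",
    "p(x_2 = 50 \\mid G_1 = h) &= 0.9\\,(0.12616) + 0.1\\,(0.00085) \\approx 0.11363 \\\\\n",
    "p(x_2 = 50 \\mid G_1 = d) &= 0.1\\,(0.12616) + 0.9\\,(0.00085) \\approx 0.01338 \\\\\n",
    "P(G_1 = d \\mid x_2 = 50) &= \\frac{0.5 \\times 0.01338}{0.5 \\times 0.11363 + 0.5 \\times 0.01338} \\approx 0.10535\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The 0.5 prior cancels, so the result reduces to $(0.1 + 0.9\\,e^{-5}) / (1 + e^{-5})$."
   ]
  },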
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Brief digression on MCMC error\n",
    "---------------\n",
    "\n",
    "I was curious about the MCMC error... does the sampler actually know how good a job it's doing? Running the sampler again with a factor less samples certainly gives a worse estimate:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with bayesnet:\n",
    "    mytrace = pm.sample(2000);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p_g1_healthy = sum(mytrace['g1'])/len(mytrace['g1'])\n",
    "p_g1_not_healthy = 1 - p_g1_healthy\n",
    "p_g1_healthy, p_g1_not_healthy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pm.summary(mytrace, round_to=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stats = pm.summary(mytrace, 'g1', round_to=5)\n",
    "err = round(float(stats.mcse_mean),5)\n",
    "plus = round(float(1-stats['mean'] + stats.mcse_mean),5)\n",
    "minus = round(float(1-stats['mean'] - stats.mcse_mean),5)\n",
    "ans = round(float(1-stats['mean']),5)\n",
    "\n",
    "md(\"\"\"The new answer is {} +/- {}. Given that the analytic answer is 0.10535, this \n",
    "   is clearly a worse estimate... but 0.10535 still falls within the {} to {} bounded \n",
    "   error envelope, so we know the answer is less precise. If we need more precision, we can \n",
    "   simply run more samples till the MC error is acceptably low.\"\"\".format(ans,err,minus,plus))"
   ]
  },
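  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A rough way to see why fewer samples give a larger error, assuming the standard definition of the Monte Carlo standard error (MCSE) of a posterior mean, with posterior standard deviation $\\sigma$ and effective sample size $N_{\\mathrm{eff}}$ (symbols introduced here for illustration):\n",
    "\n",
    "$$\n",
    "\\mathrm{MCSE} \\approx \\frac{\\sigma}{\\sqrt{N_{\\mathrm{eff}}}}\n",
    "$$\n",
    "\n",
    "For a binary variable such as $G_1$ with posterior mean $p$, $\\sigma = \\sqrt{p\\,(1 - p)}$. If the effective sample size scales with the number of draws, cutting the draws from 10000 to 2000 should inflate the error by a factor of about $\\sqrt{5} \\approx 2.2$."
   ]
  },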
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercise 2\n",
    "\n",
    "Compute P(X3  = 50 | X2 = 50). That is, what is the conditional probability that child 2 (**$X_3$**) has blood pressure equal to 50 given that their sibling (child 1, **$X_2$**) has normal blood pressure (50). As before, we simply fix the variable of interest using the observed flag:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with Model() as bayesnet:\n",
    "    # Parent\n",
    "    g1_healthy = pm.distributions.discrete.Bernoulli('g1',0.5)\n",
    "    \n",
    "    # Latent variable that conditions the children \n",
    "    g1_L = switch(g1_healthy > 0.5, 0.9, 0.1)\n",
    "    \n",
    "    # Left Child\n",
    "    g2_healthy = pm.distributions.discrete.Bernoulli('g2', g1_L)\n",
    "    # Right Child\n",
    "    g3_healthy = pm.distributions.discrete.Bernoulli('g3', g1_L)\n",
    "    # Conditional switches for blood pressure mean\n",
    "    rate1 = switch(g1_healthy > .5, 50, 60)\n",
    "    rate2 = switch(g2_healthy > .5, 50, 60)\n",
    "    rate3 = switch(g3_healthy > .5, 50, 60)\n",
    "    # Grandkids and child\n",
    "    x1 = Normal('x1', mu=rate1, sd=10**.5)\n",
    "    x2 = Normal('x2', mu=rate2, sd=10**.5)\n",
    "    # Setting x3 as an observed variable\n",
    "    x3 = Normal('x3', mu=rate3, sd=10**.5, observed=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with bayesnet:\n",
    "    mytrace = pm.sample(25000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the variable that we're interested in is continuous, it isn't as obvious how to answer our question... we could of course cheat, and cast *$X_2$* as discrete value:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n1 = mytrace['x2']>51\n",
    "n2 = mytrace['x2']<50\n",
    "\n",
    "(len(mytrace['x2']) - sum(n1) - sum(n2))/len(mytrace['x2'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The analytic solution is 0.10306, so casting **$X_2$** as a discrete variable gives a close approximation, but is also kinda unsatisfying because:\n",
    "\n",
    "- The width for the approximation isn't always obvious. Here, for blood pressure, using a natural number is a pretty clear fit-- but this isn't as obvious for other cases.\n",
    "\n",
    "- We don't have an error estimation on the approximation. This seems not very Bayesian."
   ]
  },
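  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch of where that analytic value comes from, assuming the model structure used above: the two children are conditionally independent given the parent, so the parent's posterior from Exercise 1 can be pushed forward to the other child. The labels $h$ (healthy) and $d$ (diseased) are shorthand introduced here, and strictly speaking the result is a probability density at 50 rather than a probability.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "p(x = 50 \\mid G_1 = h) &= 0.9\\,\\mathcal{N}(50;\\, 50,\\, 10) + 0.1\\,\\mathcal{N}(50;\\, 60,\\, 10) \\approx 0.11363 \\\\\n",
    "p(x = 50 \\mid G_1 = d) &= 0.1\\,\\mathcal{N}(50;\\, 50,\\, 10) + 0.9\\,\\mathcal{N}(50;\\, 60,\\, 10) \\approx 0.01338 \\\\\n",
    "p(x_3 = 50 \\mid x_2 = 50) &= \\sum_{g \\in \\{h, d\\}} p(x_3 = 50 \\mid G_1 = g)\\, P(G_1 = g \\mid x_2 = 50) \\\\\n",
    "&\\approx 0.11363 \\times 0.89465 + 0.01338 \\times 0.10535 \\approx 0.1031\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Carrying more digits through the arithmetic gives the 0.10306 quoted above."
   ]
  },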
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We could also try to cheat an answer by reading **$X_2$** directly off of the trace plots:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pm.traceplot(mytrace);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pm.trace_to_dataframe(mytrace)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.x2.hist(bins=200,density=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.x2.plot.kde()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we've fixed **$X_3$**, we don't see it in the plots, but we do get the **$X_2$** variable that was previously omitted when it was fixed. We can see that the probability for **$X_2$**= 50 is close to our analytic solution, and we don't have to decide a width, but:\n",
    "\n",
    "- We still don't have an error estimate on the parameter\n",
    "- We're just eyeballing a value off of a graph. There's surely a more numerically precise way to do this..."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Estimating a numerically robust solution\n",
    "\n",
    "My preferred solution is to estimate an empirical PDF (Probability Density Function)  via KDE (Kernel Density Estimation) using the trace of **$X_2$** as the input. This allows probability estimation of any value, including decimal values:"
   ]
  },
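  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, a Gaussian KDE has the textbook form below, where $x_i$ are the $n$ trace samples and $h$ is the bandwidth. This is the general definition rather than anything specific to this notebook; by default SciPy's `gaussian_kde` picks the bandwidth using Scott's rule.\n",
    "\n",
    "$$\n",
    "\\hat{f}_h(x) = \\frac{1}{n h} \\sum_{i=1}^{n} K\\!\\left(\\frac{x - x_i}{h}\\right), \\qquad K(u) = \\frac{1}{\\sqrt{2\\pi}}\\, e^{-u^2/2}\n",
    "$$\n",
    "\n",
    "The bandwidth plays a role similar to the window width in the counting approach, but it is chosen from the data rather than by hand."
   ]
  },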
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats.kde import gaussian_kde"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_pdf = gaussian_kde(mytrace['x2'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_pdf(50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This treats the continuous variable as continuous, so there is no need to decide an integration width. It is also numerically precise. For error bounds, we can still get them from the trace summary (using the **$X_2$** `mc_error`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pm.summary(mytrace, round_to=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This works because we have a numerically precise value of the probability, so we can port the error estimates over. We couldn't do this using the first window method because we were looking over a range interval, and didn't have a precise estimate when looking at just the trace graphs.\n",
    "\n",
    "\n",
    "Another thing to note-- because statistics are built of an empirically sampled distribution that is rendered via simulation, we are able to retrieve answers from distributions that are not classic well behaved Gaussian's. The distributions **$X_1, X_2,$** and **$X_3$** are not well behaved; they cannot be approximated by a Gaussian, and it's unclear if they could be even be modeled by a combination of classical statical functions. But, using MCMC estimation, we fundamentally don't care what the output distribution looks like; we can query the trace for probability information regardless of it's form. As more explicit example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We'll complicate our sampling by adding in more data\n",
    "# pymc3 allows observed variables to be numpy arrays, pandas\n",
    "# dataframes, or theano data structures; so you can add\n",
    "# quite a bit of empirical data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_1 = [38,58,62,55,80,40,51,52]\n",
    "data_2 = [32,110,72,43,66,55,59,52]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with Model() as bayesnet:\n",
    "    # Parent\n",
    "    g1_healthy = pm.distributions.discrete.Bernoulli('g1',0.5)\n",
    "    \n",
    "    # Latent variable that conditions the children \n",
    "    g1_L = switch(g1_healthy > 0.5, 0.9, 0.1)\n",
    "    \n",
    "    # Left Child\n",
    "    g2_healthy = pm.distributions.discrete.Bernoulli('g2', g1_L)\n",
    "    # Right Child\n",
    "    g3_healthy = pm.distributions.discrete.Bernoulli('g3', g1_L)\n",
    "    # Conditional switches for blood pressure mean\n",
    "    rate1 = switch(g1_healthy > .5, 50, 60)\n",
    "    rate2 = switch(g2_healthy > .5, 50, 60)\n",
    "    rate3 = switch(g3_healthy > .5, 50, 60)\n",
    "    # Grandkids and child\n",
    "    x1 = Normal('x1', mu=rate1, sd=10**.5)\n",
    "    x2 = Normal('x2', mu=rate2, sd=10**.5, observed=data_2)\n",
    "    # Setting x3 as an observed variable\n",
    "    x3 = Normal('x3', mu=rate3, sd=10**.5, observed=data_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with bayesnet:\n",
    "    mytrace = pm.sample(5000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pm.traceplot(mytrace);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Above tells us that the Bayesian estimate is certain that one child has the condition, and is fairly certain that the other doesn't. This means that the parent is more likely to have the condition than not; also note that the blood pressure likelihoods are highest for each case associated with the condition (compared to a mean)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:mlenv]",
   "language": "python",
   "name": "conda-env-mlenv-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
